{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "[download this notebook here](https://github.com/HumanCompatibleAI/imitation/blob/master/docs/tutorials/10_train_custom_env.ipynb)\n",
    "# Train Behavior Cloning in a Custom Environment\n",
    "\n",
    "You can use `imitation` to train a policy (and, for many imitation learning algorithm, learn rewards) in a custom environment.\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 1: Define the environment"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We will use a simple ObservationMatching environment as an example. The premise is simple -- the agent receives a vector of observations, and must output a vector of actions that matches the observations as closely as possible.\n",
    "\n",
    "If you have your own environment that you'd like to use, you can replace the code below with your own environment. Make sure it complies with the standard Gym API, and that the observation and action spaces are specified correctly."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from typing import Dict, Optional\n",
    "from typing import Any\n",
    "import numpy as np\n",
    "import gymnasium as gym\n",
    "\n",
    "from gymnasium.spaces import Box\n",
    "\n",
    "\n",
    "class ObservationMatchingEnv(gym.Env):\n",
    "    def __init__(self, num_options: int = 2):\n",
    "        self.state = None\n",
    "        self.num_options = num_options\n",
    "        self.observation_space = Box(0, 1, shape=(num_options,))\n",
    "        self.action_space = Box(0, 1, shape=(num_options,))\n",
    "\n",
    "    def reset(self, seed: int = None, options: Optional[Dict[str, Any]] = None):\n",
    "        super().reset(seed=seed, options=options)\n",
    "        self.state = self.observation_space.sample()\n",
    "        return self.state, {}\n",
    "\n",
    "    def step(self, action):\n",
    "        reward = -np.abs(self.state - action).mean()\n",
    "        self.state = self.observation_space.sample()\n",
    "        return self.state, reward, False, False, {}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 2: create the environment"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "From here, we have two options:\n",
    "- Add the environment to the gym registry, and use it with existing utilities (e.g. `make`)\n",
    "- Use the environment directly\n",
    "\n",
    "You only need to execute the cells in step 2a, or step 2b to proceed.\n",
    "\n",
    "At the end of these steps, we want to have:\n",
    "- `env`: a single environment that we can use for training an expert with SB3\n",
    "- `venv`: a vectorized environment where each individual environment is wrapped in `RolloutInfoWrapper`, that we can use for collecting rollouts with `imitation`"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Step 2a (recommended): add the environment to the gym registry"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The standard approach is adding the environment to the gym registry."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "gym.register(\n",
    "    id=\"custom/ObservationMatching-v0\",\n",
    "    entry_point=ObservationMatchingEnv,  # This can also be the path to the class, e.g. `observation_matching:ObservationMatchingEnv`\n",
    "    max_episode_steps=500,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "After registering, you can create an environment is `gym.make(env_id)` which automatically handles the `TimeLimit` wrapper.\n",
    "\n",
    "To create a vectorized env, you can use the `make_vec_env` helper function (Option A), or create it directly (Options B1 and B2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from gymnasium.wrappers import TimeLimit\n",
    "from imitation.data import rollout\n",
    "from imitation.data.wrappers import RolloutInfoWrapper\n",
    "from imitation.util.util import make_vec_env\n",
    "from stable_baselines3.common.vec_env import DummyVecEnv, SubprocVecEnv\n",
    "\n",
    "# Create a single environment for training an expert with SB3\n",
    "env = gym.make(\"custom/ObservationMatching-v0\")\n",
    "\n",
    "\n",
    "# Create a vectorized environment for training with `imitation`\n",
    "\n",
    "# Option A: use the `make_vec_env` helper function - make sure to pass `post_wrappers=[lambda env, _: RolloutInfoWrapper(env)]`\n",
    "venv = make_vec_env(\n",
    "    \"custom/ObservationMatching-v0\",\n",
    "    rng=np.random.default_rng(),\n",
    "    n_envs=4,\n",
    "    post_wrappers=[lambda env, _: RolloutInfoWrapper(env)],\n",
    ")\n",
    "\n",
    "\n",
    "# Option B1: use a custom env creator, and create VecEnv directly\n",
    "# def _make_env():\n",
    "#     \"\"\"Helper function to create a single environment. Put any logic here, but make sure to return a RolloutInfoWrapper.\"\"\"\n",
    "#     _env = gym.make(\"custom/ObservationMatching-v0\")\n",
    "#     _env = RolloutInfoWrapper(_env)\n",
    "#     return _env\n",
    "#\n",
    "# venv = DummyVecEnv([_make_env for _ in range(4)])\n",
    "#\n",
    "# # Option B2: we can also use a parallel VecEnv implementation\n",
    "# venv = SubprocVecEnv([_make_env for _ in range(4)])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "\n",
    "### Step 2b: directly use the environment\n",
    "\n",
    "Alternatively, we can directly initialize the environment by instantiating the class we created earlier, and handle all the additional logic ourselves."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from gymnasium.wrappers import TimeLimit\n",
    "from imitation.data import rollout\n",
    "from imitation.data.wrappers import RolloutInfoWrapper\n",
    "from stable_baselines3.common.vec_env import DummyVecEnv\n",
    "import numpy as np\n",
    "\n",
    "# Create a single environment for training with SB3\n",
    "env = ObservationMatchingEnv()\n",
    "env = TimeLimit(env, max_episode_steps=500)\n",
    "\n",
    "# Create a vectorized environment for training with `imitation`\n",
    "\n",
    "\n",
    "# Option A: use a helper function to create multiple environments\n",
    "def _make_env():\n",
    "    \"\"\"Helper function to create a single environment. Put any logic here, but make sure to return a RolloutInfoWrapper.\"\"\"\n",
    "    _env = ObservationMatchingEnv()\n",
    "    _env = TimeLimit(_env, max_episode_steps=500)\n",
    "    _env = RolloutInfoWrapper(_env)\n",
    "    return _env\n",
    "\n",
    "\n",
    "venv = DummyVecEnv([_make_env for _ in range(4)])\n",
    "\n",
    "\n",
    "# Option B: use a single environment\n",
    "# env = FixedHorizonCartPoleEnv()\n",
    "# venv = DummyVecEnv([lambda: RolloutInfoWrapper(env)])  # Wrap a single environment -- only useful for simple testing like this\n",
    "\n",
    "# Option C: use multiple environments\n",
    "# venv = DummyVecEnv([lambda: RolloutInfoWrapper(ObservationMatchingEnv()) for _ in range(4)])  # Wrap multiple environments"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 3: Training"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "And now we're just about done! Whether you used step 2a or 2b, your environment should now be ready to use with SB3 and `imitation`.\n",
    "\n",
    "For the sake of completeness, we'll train a BC model, the same way as in the first tutorial, but with our custom environment.\n",
    "\n",
    "Keep in mind that while we're using BC in this tutorial, you can just as easily use any of the other algorithms with the environment prepared in this way."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from stable_baselines3 import PPO\n",
    "from stable_baselines3.ppo import MlpPolicy\n",
    "from stable_baselines3.common.evaluation import evaluate_policy\n",
    "from gymnasium.wrappers import TimeLimit\n",
    "\n",
    "expert = PPO(\n",
    "    policy=MlpPolicy,\n",
    "    env=env,\n",
    "    seed=0,\n",
    "    batch_size=64,\n",
    "    ent_coef=0.0,\n",
    "    learning_rate=0.0003,\n",
    "    n_epochs=10,\n",
    "    n_steps=64,\n",
    ")\n",
    "\n",
    "reward, _ = evaluate_policy(expert, env, 10)\n",
    "print(f\"Reward before training: {reward}\")\n",
    "\n",
    "\n",
    "# Note: if you followed step 2a, i.e. registered the environment, you can use the environment name directly\n",
    "\n",
    "# expert = PPO(\n",
    "#     policy=MlpPolicy,\n",
    "#     env=\"custom/ObservationMatching-v0\",\n",
    "#     seed=0,\n",
    "#     batch_size=64,\n",
    "#     ent_coef=0.0,\n",
    "#     learning_rate=0.0003,\n",
    "#     n_epochs=10,\n",
    "#     n_steps=64,\n",
    "# )\n",
    "expert.learn(10_000)  # Note: set to 100000 to train a proficient expert\n",
    "reward, _ = evaluate_policy(expert, expert.get_env(), 10)\n",
    "print(f\"Expert reward: {reward}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "rng = np.random.default_rng()\n",
    "rollouts = rollout.rollout(\n",
    "    expert,\n",
    "    venv,\n",
    "    rollout.make_sample_until(min_timesteps=None, min_episodes=50),\n",
    "    rng=rng,\n",
    ")\n",
    "transitions = rollout.flatten_trajectories(rollouts)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from imitation.algorithms import bc\n",
    "\n",
    "bc_trainer = bc.BC(\n",
    "    observation_space=env.observation_space,\n",
    "    action_space=env.action_space,\n",
    "    demonstrations=transitions,\n",
    "    rng=rng,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As before, the untrained policy only gets poor rewards:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "reward_before_training, _ = evaluate_policy(bc_trainer.policy, env, 10)\n",
    "print(f\"Reward before training: {reward_before_training}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "After training, we can get much closer to the expert's performance:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "bc_trainer.train(n_epochs=1)\n",
    "reward_after_training, _ = evaluate_policy(bc_trainer.policy, env, 10)\n",
    "print(f\"Reward after training: {reward_after_training}\")"
   ]
  }
 ],
 "metadata": {
  "interpreter": {
   "hash": "bd378ce8f53beae712f05342da42c6a7612fc68b19bea03b52c7b1cdc8851b5f"
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
